#include <cstdio>
#include <memory.h>
#include <vector>
#include <string>
#include <cassert>
#include <algorithm>
#include <iterator>
#include <unordered_set>
#include <cmath>
#include <omp.h>
#include <atomic>
#include "chesspi.h"

// Search depth budget read by build_tree: children whose depth does not
// exceed this value are kept unconditionally.
int max_depth = 5;
// Node budget read by build_tree: expansion stops once the tree holds more
// than this many nodes.
int max_nodes = 32 * 1000 * 1000;

// Base capture value of every piece kind, indexed by (piece index % 16).
// Slot order: general, advisor x2, elephant x2, horse x2, chariot x2,
// cannon x2, soldier x5.
static const unsigned int table_cost[16] = {
	100000,
	150, 150,
	150, 150,
	150, 150,
	2000, 2000,
	150, 150,
	100, 100, 100, 100, 100
};
/*!
 * \brief calc_cost computes the cost of piece idx being captured after a move.
 * \param coordx current X coordinate of every piece
 * \param coordy current Y coordinate of every piece
 * \param alive  current alive flag of every piece
 * \param killed total number of captures so far
 * \param idx    index of the captured piece (asserted to be >= 16)
 * \return capture cost: the base value from table_cost scaled by the
 *         per-kind positional weights below
 */
float calc_cost(const int coordx[/*32*/], const int coordy[/*32*/],const int alive[/*32*/],const int killed,const int idx)
{
	assert(idx >= 16);
	const int kind = idx % 16;
	unsigned int value = table_cost[kind];

	// Every "value *= <double>" below truncates back to unsigned int, so the
	// individual scaling steps are deliberately kept separate.
	if (kind >= 1 && kind <= 4)
	{
		// Advisors / elephants: doubled when either slot of the piece's own
		// pair is flagged alive.
		const int pair_first = (idx-1)/2*2+1;
		const int pair_second = pair_first + 1;
		if (alive[pair_first] || alive[pair_second])
			value *= 2;
	}
	else if (kind == 5 || kind == 6)
	{
		// Horses: worth more the smaller coordy is, and more as the number
		// of captures grows.
		const double advance = 1+((11-coordy[idx])/3.0);
		const double lateness = 1 + killed / 10.0;
		value *= advance;
		value *= lateness;
	}
	else if (kind == 7 || kind == 8)
	{
		// Chariots: worth more the smaller coordy is.
		const double advance = 1+((11-coordy[idx])/3.0);
		value *= advance;
	}
	else if (kind == 9 || kind == 10)
	{
		// Cannons: worth more the smaller coordy is, and more while few
		// captures have happened.
		const double advance = 1+((11-coordy[idx])/3.0);
		const double earliness = 1 + (32 - killed) / 10.0;
		value *= advance;
		value *= earliness;
	}
	else if (kind >= 11)
	{
		// Soldiers: x4 once coordy is below 6 (past the river), and a
		// further x5 on the centre file (coordx == 5).
		if (coordy[idx] < 6)
			value *= 4;
		if (coordx[idx] == 5)
			value *= 5;
	}
	// kind == 0 (general) keeps its base value.

	return static_cast<float>(value);
}

std::vector<chess_node> build_tree(const chess_node & root, const int side,const std::vector<chess_node> & history)
{
	std::vector<chess_node> tree;
	std::unordered_set <size_t> dict;
	for (const chess_node & n: history)
		dict.insert(node2hash(n.coords,n.alive));
	tree.push_back(root);
	tree[0].side = side % 2;
	tree[0].depth = 0;
	size_t curr_i = 0;
	//要停留在敌走的偶数步
	const int stop_depth = (max_depth+1)/2 * 2;
	printf ("Max Nodes = %d\n",max_nodes);
	while (tree.size()<=max_nodes && curr_i<tree.size())
	{
		const size_t ts = tree.size();
		const int cores = omp_get_num_procs();
		std::vector<std::vector<chess_node> > vec_appends;
		for (int i=0;i<cores;++i)
			vec_appends.push_back(std::vector<chess_node>());
		std::atomic<int> new_appends (0);

#pragma omp parallel for
		for (int i=curr_i;i<ts;++i)
		{
			if (new_appends + ts >=max_nodes)
				continue;
			const unsigned char clock = tree[i].depth;
			if (clock >= stop_depth)
				continue;
			bool onlykill = clock >=max_depth;
			const int tid = omp_get_thread_num();
			if ((tree[i].alive & 0x00010001)==0x00010001)
			{
				const int curr_side = (side + clock) % 2;
				std::vector<chess_node> next_status =
						expand_node(tree[i],curr_side,onlykill);

				const size_t sz = next_status.size();
				for (size_t j=0;j<sz;++j)
				{
					size_t ha = node2hash(next_status[j].coords,next_status[j].alive);
					bool needI = false;
#pragma omp critical
					{
						if (dict.find(ha)==dict.end() && clock+1 <= max_depth)
						{
							needI = true;
							dict.insert(ha);
						}
						else if (dict.find(ha)==dict.end() && clock + 1 <= stop_depth)
						{
							if (next_status[j].jump_cost[0]+next_status[j].jump_cost[1]>0)
							{
								needI = true;
								dict.insert(ha);
							}
						}
					}
					if (needI)
					{
						next_status[j].parent = i;
						next_status[j].side = curr_side;
						next_status[j].depth = clock+1;
						vec_appends[tid].push_back(next_status[j]);
						//++tree[i].leaves;
						++new_appends;
						if (new_appends%1000==0)
						{
							printf ("Thinking.%d:%d...  \r",i,int(new_appends+ts));
						}
					}
				}
			}
		}
		for (int i=0;i<cores;++i)
		{
			if (vec_appends[i].size())
				std::move(vec_appends[i].begin(),vec_appends[i].end(),std::back_inserter(tree));
		}
		curr_i += (ts - curr_i);
	}

	printf ("\nDepth = %d                \n",tree.rbegin()->depth);

	return tree;
}

size_t judge_tree(std::vector<chess_node> & tree)
{
	const size_t total_nodes = tree.size();
	if (total_nodes<2)
		return 0;
	int side = tree[0].side;
	size_t i = total_nodes - 1;
	while (i > 0)
	{
		if (tree[i].side==0)
		{
			float ratio = sqrt((tree[i].jump_cost[1]+1) / (tree[i].jump_cost[0]+1)/ (tree[i].jump_cost[0]+1));
			tree[i].weight = ratio;
		}
		else
		{
			float ratio = sqrt((tree[i].jump_cost[0]+1) / (tree[i].jump_cost[1]+1)/ (tree[i].jump_cost[1]+1));
			tree[i].weight = ratio;
		}
		size_t parent = tree[i].parent;
		tree[parent].jump_cost[0] += tree[i].jump_cost[0] * tree[i].weight/tree[i].depth/tree[i].depth;
		tree[parent].jump_cost[1] += tree[i].jump_cost[1] * tree[i].weight/tree[i].depth/tree[i].depth;
		--i;
		if (i%1000==0)
			printf ("Sorting.%d...  \r",total_nodes - i);
	}

	size_t p = 1;
	float max_v = 0;
	int max_p = 1;
	while (p<total_nodes)
	{
		if (tree[p].parent)
			break;
		//float v = (tree[p].jump_cost[1-side]+1)/(tree[p].jump_cost[side]);
		float v = (tree[p].weight);
		if (v > max_v)
		{
			max_v = v;
			max_p = p;
		}
		++p;
	}
	return max_p;
}
